# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import contextlib
from dataclasses import dataclass
import os
from pathlib import Path
import torch.distributed
import logging

from torch.profiler.profiler import profile
import xformers.profiler
from xformers.profiler import (
    MemSnapshotsProfiler,
    PyTorchProfiler,
)

from lingua.distributed import get_is_master

import wandb


@dataclass
class ProfilerArgs:
    run: bool = False
    trace_folder: str = "profiling"
    mem_warmup: int = 100
    mem_steps: int = 2
    profile_warmup: int = 102
    profile_steps: int = 2


logger = logging.getLogger()


def perfetto_to_html(json_file, html_file):
    import viztracer
    import gzip
    import string

    root = os.path.dirname(viztracer.__file__)
    sub = {}
    json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file)
    with open(
        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
        sub["trace_viewer_full"] = f.read()
    with json_file as j:
        content = j.read()
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        sub["json_data"] = content.replace("</script>", "<\\/script>")  # type: ignore
    with open(html_file, "w+", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))


class PyTorchProfilerWandb(PyTorchProfiler):
    def __init__(self, main_profiler) -> None:
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # With stack gives huge profile traces
            # and bugs out because of some non ascii
            # character somewhere in pytorch
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        super()._on_trace(prof)
        if get_is_master() and wandb.run is not None:
            filename = list(
                Path(self.main_profiler.output_dir).glob(
                    "profile_CPU_CUDA*/*.pt.trace.json*"
                )
            )[0]
            html_path = str(filename).replace(".json", ".html")
            perfetto_to_html(filename, html_path)
            wandb.log({"profile_trace": wandb.Html(html_path)})


class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        if get_is_master() and wandb.run is not None:
            filename = list(
                Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
            )[0]
            wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)})


@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    # get user defined profiler settings

    if config.run:
        trace_dir = os.path.join(dump_dir, config.trace_folder)

        logger.info(f"Profiling active.  Traces will be saved at {trace_dir}")

        if get_is_master() and not os.path.exists(trace_dir):
            os.makedirs(trace_dir)
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        with xformers.profiler.profile(
            output_dir=trace_dir,
            module=module,
            schedule=[
                (
                    MemSnapshotsProfilerWandb,
                    config.mem_warmup,
                    config.mem_warmup + config.mem_steps,
                ),
                (
                    PyTorchProfilerWandb,
                    config.profile_warmup,
                    config.profile_warmup + config.profile_steps,
                ),
            ],
        ) as profiler:
            yield profiler

    else:
        torch_profiler = contextlib.nullcontext()
        yield None
